import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
import torchsde
import random
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import argparse

# Example: Assume 'data' is a numpy array of shape (100, 50, 253)
# 100 samples, 50 assets, 253 time points (252 historical + 1 today)
class GBM(nn.Module):
    """Geometric-Brownian-motion coefficients with learnable ``mu`` and ``sigma``.

    The drift is ``f(t, S) = mu * S`` and the diffusion is
    ``g(t, S) = sigma * S``; the time argument is accepted but not used.

    Args:
        input_dim: Accepted for interface compatibility; not used by this class.
        mu: Drift coefficient (Python number, array-like or tensor).
        sigma: Diffusion coefficient (same accepted kinds as ``mu``).
    """

    def __init__(self, input_dim, mu, sigma):
        super(GBM, self).__init__()
        self.mu = nn.Parameter(self._as_float_copy(mu))
        self.sigma = nn.Parameter(self._as_float_copy(sigma))

    @staticmethod
    def _as_float_copy(value):
        """Return an independent float32 tensor holding ``value``."""
        # torch.tensor(<tensor>) emits a copy-construct UserWarning; going
        # through as_tensor + detach().clone() accepts numbers, arrays and
        # tensors alike and never shares storage or autograd history with
        # the caller's object.
        return torch.as_tensor(value, dtype=torch.float32).detach().clone()

    # Drift function
    def f(self, t, S):
        return self.mu * S

    # Diffusion function
    def g(self, t, S):
        return self.sigma * S

# Define the SDE class with required noise_type and sde_type attributes
class SDE(nn.Module):
    """SDE wrapper whose drift and diffusion are delegated to a ``GBM`` module.

    The class-level ``noise_type`` and ``sde_type`` attributes declare
    diagonal noise and an Ito-type SDE.

    Args:
        input_dim: Forwarded to ``GBM``.
        mu: Drift coefficient forwarded to ``GBM``.
        sigma: Diffusion coefficient forwarded to ``GBM``.
    """

    sde_type = "ito"  # Ito (as opposed to Stratonovich) interpretation
    noise_type = "diagonal"  # kind of noise driving the SDE

    def __init__(self, input_dim, mu, sigma):
        super().__init__()
        self.drift = GBM(input_dim, mu, sigma)

    def f(self, t, S):
        """Drift term evaluated by the wrapped GBM module."""
        drift_term = self.drift.f(t, S)
        return drift_term

    def g(self, t, S):
        """Diffusion term evaluated by the wrapped GBM module."""
        diffusion_term = self.drift.g(t, S)
        return diffusion_term
    
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    When CUDA is available the CUDA generators are seeded as well and cuDNN
    is switched to deterministic, non-benchmark mode.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Create a PyTorch Dataset
class TimeSeriesDataset(Dataset):
    """Dataset pairing each input in ``X`` with the target at the same index in ``y``.

    Args:
        X: Indexable collection of inputs; its length is the dataset length.
        y: Indexable collection of targets, indexed in parallel with ``X``.
    """

    def __init__(self, X, y):
        self.X, self.y = X, y

    def __len__(self):
        n_items = len(self.X)
        return n_items

    def __getitem__(self, idx):
        features = self.X[idx]
        target = self.y[idx]
        return features, target

# Define the LSTM model
class LSTMModel(nn.Module):
    """Stacked LSTM followed by a linear layer applied at every time step.

    Args:
        num_layers: Number of stacked LSTM layers.
        hidden_size: Width of the LSTM hidden state.
        input_size: Number of features per time step.
        output_size: Number of outputs produced per time step.
    """

    def __init__(self,  num_layers=2, hidden_size=50, input_size=1, output_size=1):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Run ``x`` (batch-first) through the LSTM from a zero state.

        Returns the linear layer's output for every time step of the sequence.
        """
        # One zero tensor per state (hidden and cell), sized for this batch.
        state_shape = (self.num_layers, x.size(0), self.hidden_size)
        initial_hidden = torch.zeros(state_shape, device=x.device)
        initial_cell = torch.zeros(state_shape, device=x.device)
        sequence_out, _ = self.lstm(x, (initial_hidden, initial_cell))
        return self.fc(sequence_out)

# Prediction from initial condition
def predict_full_series(model, initial_condition, time_steps=252):
    """Roll the model forward autoregressively from a single starting value.

    The model is put in eval mode; each prediction is fed back as the next
    one-step input while the LSTM state is carried across steps.

    Args:
        model: Module exposing ``lstm``, ``fc``, ``num_layers`` and ``hidden_size``.
        initial_condition: Scalar starting value for the roll-out.
        time_steps: Number of successive predictions to generate.

    Returns:
        A list of ``time_steps`` Python floats, in prediction order.
    """
    model.eval()
    device = next(model.parameters()).device

    step_input = torch.tensor(initial_condition, dtype=torch.float32).view(1, 1, 1).to(device)
    state_shape = (model.num_layers, 1, model.hidden_size)
    hidden = torch.zeros(state_shape).to(device)
    cell = torch.zeros(state_shape).to(device)

    rollout = []
    with torch.no_grad():
        for _ in range(time_steps):
            lstm_out, (hidden, cell) = model.lstm(step_input, (hidden, cell))
            next_value = model.fc(lstm_out[:, -1, :])
            rollout.append(next_value.item())
            # Feed the prediction back in as the next one-step sequence.
            step_input = next_value.view(1, 1, 1)
    return rollout

# Prepare sequences for LSTM
def create_sequences(data, time_steps=252):
    """Build sliding-window inputs and one-step-shifted targets.

    For every row of ``data`` and every start offset that leaves room for a
    full window plus one extra point, the input is the window of length
    ``time_steps`` and the target is the same window shifted forward by one.

    Args:
        data: Array indexed as ``data[row, time]``.
        time_steps: Window length.

    Returns:
        Tuple ``(X, y)`` of NumPy arrays, ordered row by row and then by
        start offset.
    """
    n_series, series_len = data.shape[0], data.shape[1]
    n_windows = series_len - time_steps  # non-positive -> no windows at all
    window_starts = [
        (row, start)
        for row in range(n_series)
        for start in range(n_windows)
    ]
    inputs = [data[row, start:start + time_steps] for row, start in window_starts]
    targets = [data[row, start + 1:start + time_steps + 1] for row, start in window_starts]
    return np.array(inputs), np.array(targets)


if __name__ == '__main__':
    # Script entry point: simulate asset-price paths with the SDE model,
    # train the LSTM on one-step-ahead windows, pick the best checkpoint on
    # the validation split, score it on the test split, plot one roll-out and
    # append the scores to a CSV file.

    # Local import: only the entry point needs it (output directories).
    import os

    def _str2bool(text):
        """Parse a command-line boolean (``type=bool`` would turn 'False' into True)."""
        lowered = str(text).strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {text!r}')

    parser = argparse.ArgumentParser(description='LSTM_ASSET_PRICE_DYN')
    ### NN HYPERPARAMS
    parser.add_argument('--seed', type=int, default=1001, help='random seed')
    parser.add_argument('--nHiddenUnit', type=int, default=50, help='number of hidden units')
    parser.add_argument('--activation', type=str, default="RELU", help='activation_function')
    parser.add_argument('--optimizer', type=int, default=2, help='GD algorithm')
    parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
    parser.add_argument('--batchsize', type=int, default=10, help='training batch size')
    parser.add_argument('--train_test_and_valid_split', type=float, default=.2)
    parser.add_argument('--normalize', type=_str2bool, default=False)
    parser.add_argument('--max_epochs', type=int, default=700, help='max training epochs')
    parser.add_argument('--lambdA', type=float, default=.5)

    parser.add_argument('--mu_lb', type=float, default=.1)
    parser.add_argument('--mu_ub', type=float, default=1)
    parser.add_argument('--sigma_lb', type=float, default=.01)
    parser.add_argument('--sigma_ub', type=float, default=.5)

    parser.add_argument('--max_patience', type=int, default=5)
    parser.add_argument('--nSamples', type=int, default=10000, help='number of samples')
    parser.add_argument('--nLayer', type=int, default=2, help='number of layers')
    parser.add_argument('--id', type=int, default=5, help='run identifier used in output file names')

    args = vars(parser.parse_args())  # change to dictionary

    set_seed(args['seed'])
    nex = 30
    # Path-index boundaries: [0, n_train_paths) start from the training
    # initial prices, [n_train_paths, n_valid_end) from the validation ones,
    # the remainder from the test ones.
    n_train_paths = int(nex * .8)
    n_valid_end = int(nex * .9)

    num_assets = 50
    init_X_train = np.load('portfolio_data/init_asset_prices_training_8000.npy').T
    init_X_valid = np.load('portfolio_data/init_asset_prices_validation_1000.npy').T
    init_X_test = np.load('portfolio_data/init_asset_prices_test_1000.npy').T
    T = 1.0
    time_steps = 252
    num_paths = nex
    # time_steps intervals -> time_steps + 1 grid points. Counting points
    # from the integer avoids relying on int(T / dt) rounding to 252.
    time_points = torch.linspace(0, T, time_steps + 1)
    mu = np.random.uniform(0.5, 1)
    sigma = np.random.uniform(0.05, 0.1)
    # Convert parameters to tensors
    mu_torch = torch.tensor(mu, dtype=torch.float32)
    sigma_torch = torch.tensor(sigma, dtype=torch.float32)
    # One SDE instance serves every path: mu and sigma are fixed for the run.
    sde_model = SDE(num_assets, mu_torch, sigma_torch)

    # Simulate asset dynamics: given a range of s(0), the price evolves
    # following the SDE dynamics, with fixed parameters but stochastic
    # behavior. The training aims to instruct the model to predict a future
    # price trend.
    price_paths = []
    for index in range(num_paths):
        if index < n_train_paths:
            S0 = init_X_train[index, :]
        elif index < n_valid_end:
            S0 = init_X_valid[index - n_train_paths, :]
        else:
            S0 = init_X_test[index - n_valid_end, :]
        S0_torch = torch.tensor(S0, dtype=torch.float32).unsqueeze(0)  # initial asset prices, batch of one
        S_paths = torch.squeeze(torchsde.sdeint(sde_model, S0_torch, time_points, method='euler'))
        price_paths.append(S_paths.detach().numpy())
        print(index + 1)  # 1-based progress counter

    # Stack the paths along a new leading axis, then swap the last two axes
    # so the array unpacks as (sample, asset, time).
    data = np.swapaxes(np.stack(price_paths, axis=0), 1, 2)
    num_samples, num_assets, num_time_points = data.shape

    # No normalization is applied; the name is kept for the code below.
    data_normalized = data

    # Combine the data from all assets into a single dataset for training
    combined_data = data_normalized.reshape(-1, num_time_points)  # Shape: (num_samples*num_assets, time_points)

    X, y = create_sequences(combined_data, time_steps)

    # Check if X and y are not empty
    print(f"X shape: {X.shape}, y shape: {y.shape}")

    # Convert to PyTorch tensors
    X = torch.tensor(X, dtype=torch.float32).unsqueeze(-1)  # Add feature dimension
    y = torch.tensor(y, dtype=torch.float32).unsqueeze(-1)

    # Create DataLoader
    batch_size = args['batchsize']

    # 90% train; the remaining 10% is halved into validation and test.
    X_train, X_valid, Y_train, Y_valid = train_test_split(X, y, test_size=0.1, random_state=1)
    X_valid, X_test, Y_valid, Y_test = train_test_split(X_valid, Y_valid, test_size=0.5)

    train_data = TimeSeriesDataset(X_train, Y_train)
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

    valid_data = TimeSeriesDataset(X_valid, Y_valid)
    valid_loader = DataLoader(valid_data, batch_size=len(valid_data), shuffle=False)

    test_data = TimeSeriesDataset(X_test, Y_test)
    test_loader = DataLoader(test_data, batch_size=len(test_data), shuffle=False)

    # Instantiate the model, define the loss function and the optimizer
    model = LSTMModel(args['nLayer'], args['nHiddenUnit'])
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=args['lr'])

    # Train the model
    # NOTE(review): --max_epochs is parsed but not used; the epoch count is
    # fixed here. Confirm whether args['max_epochs'] should be used instead.
    num_epochs = 10
    # Start from +inf so the first finite validation loss always produces a
    # checkpoint, whatever the scale of the (unnormalized) prices.
    min_loss = float('inf')
    patience, max_patience = 0, args['max_patience']
    lambdA = args['lambdA']

    os.makedirs('best_LSTM_model', exist_ok=True)
    best_model_path = f"best_LSTM_model/model_{args['id']}.pt"
    checkpoint_saved = False

    for epoch in range(num_epochs):
        model.train()
        for batch_X, batch_y in train_loader:
            optimizer.zero_grad()
            outputs = model(batch_X)
            # Whole-sequence MSE plus an extra, lambdA-weighted penalty on
            # the final time step.
            seq_loss = criterion(outputs, batch_y)
            last_step_loss = lambdA * criterion(outputs[:, -1, :], batch_y[:, -1, :])
            loss = seq_loss + last_step_loss
            print(last_step_loss)
            print(seq_loss)
            loss.backward()
            optimizer.step()
        print(f'epoch [{epoch+1}/{num_epochs}], training loss: {loss.item():.4f}')

        # Model selection and early stopping use the validation split only,
        # so the test split stays untouched until the final evaluation.
        model.eval()
        with torch.no_grad():
            for batch_X, batch_y in valid_loader:
                outputs = model(batch_X)
                valid_loss = criterion(outputs[:, -1, :], batch_y[:, -1, :]).item()
        print(f'epoch [{epoch+1}/{num_epochs}], validation loss: {valid_loss:.4f}')

        if valid_loss < min_loss:
            torch.save(model.state_dict(), best_model_path)
            checkpoint_saved = True
            min_loss = valid_loss
            patience = 0
        else:
            patience += 1
            if patience >= max_patience:
                break

    if checkpoint_saved:
        model.load_state_dict(torch.load(best_model_path))
    else:
        # Happens only if no epoch produced a comparable (e.g. non-NaN) loss;
        # never load a stale checkpoint left over from an earlier run.
        print('warning: no checkpoint was saved during training; keeping the current weights')

    # Final, one-off evaluation of the selected weights on the test split.
    model.eval()
    with torch.no_grad():
        for batch_X, batch_y in test_loader:
            outputs = model(batch_X)
            test_loss = criterion(outputs[:, -1, :], batch_y[:, -1, :]).item()
    print(f'epoch [{epoch+1}/{num_epochs}], test loss: {test_loss:.4f}')

    # Predict the entire time series for each asset from the initial
    # condition. The same sample is used for the roll-out and for the
    # reference curves in the plot, so the two are comparable.
    plot_sample = -1
    predictions_all_assets = []
    for asset_idx in range(num_assets):
        initial_condition = data_normalized[plot_sample, asset_idx, 0]  # first time point of the path
        predictions_all_assets.append(
            predict_full_series(model, initial_condition, time_steps=time_steps)
        )

    # Plot the results for the first asset
    asset_idx = 0
    plt.figure(figsize=(12, 6))
    plt.plot(np.arange(time_steps), data[plot_sample, asset_idx, :time_steps], label='Historical Prices')
    # The tail past the window is a single point, so it needs a marker to be visible.
    plt.plot(np.arange(time_steps, num_time_points), data[plot_sample, asset_idx, time_steps:],
             label='Actual Future Prices', color='green', marker='o')
    plt.plot(np.arange(1, time_steps + 1), predictions_all_assets[asset_idx],
             label='Predicted Future Prices', color='red')
    plt.title(f'Prediction for Asset {asset_idx + 1}')
    plt.legend()
    plt.show()

    # Column order matters: the CSV is appended without a header.
    record = {
        'id': [args['id']],
        'MSE_valid': [min_loss],
        'MSE_test': [test_loss],
    }

    os.makedirs('best_NSDE_model', exist_ok=True)
    torch.save(model.state_dict(), f"best_NSDE_model/model_{args['id']}.pt")
    df = pd.DataFrame(record)
    df.to_csv('LSTM_dyn_portfolio_results.csv', mode='a', header=False, index=False)